{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Composition generation\n",
    "Here, we generate a set of quaternary oxide compositions using a modified `smact_filter` function and then turn the results into a dataframe with features that can be read by a machine learning algorithm.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Imports ###\n",
    "import smact\n",
    "from smact import screening\n",
    "from itertools import combinations, product\n",
    "import multiprocessing\n",
    "from pymatgen import Composition\n",
    "import pandas as pd\n",
    "from matminer.featurizers.conversions import StrToComposition\n",
    "from matminer.featurizers.base import MultipleFeaturizer\n",
    "from matminer.featurizers import composition as cf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the elements we are interested in\n",
    "all_el = smact.element_dictionary()\n",
    "symbol_list = list(all_el.keys())\n",
    "# Symbols left out of the search space. 'O' is listed because smact_filter\n",
    "# appends oxygen to every combination itself.\n",
    "do_not_want = ['H', 'He', 'B', 'C', 'O', 'Ne', 'Ar', 'Kr', 'Tc', 'Xe', 'Rn',\n",
    "              'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', \n",
    "              'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr',\n",
    "              'Ra', 'Fr', 'At', 'Po', 'Pm', 'Eu', 'Tb', 'Yb']\n",
    "good_elements = [all_el[x] for x in symbol_list if x not in do_not_want]\n",
    "\n",
    "# Materialise the 3-element combinations as a list. A bare `combinations`\n",
    "# iterator is exhausted after a single pass, so re-running the cell that maps\n",
    "# over it (without re-running this one) would silently produce no results.\n",
    "all_el_combos = list(combinations(good_elements, 3))\n",
    "\n",
    "def smact_filter(els):\n",
    "    \"\"\"Screen every oxidation-state combination of three elements plus oxygen.\n",
    "\n",
    "    Args:\n",
    "        els: three element objects, each providing `symbol`, `pauling_eneg`\n",
    "            and `oxidation_states`.\n",
    "\n",
    "    Returns:\n",
    "        A list of `(symbols, ratio)` tuples, one per oxidation-state\n",
    "        combination that passes the charge-balance and electronegativity\n",
    "        tests. `symbols` is the three element symbols followed by 'O';\n",
    "        `ratio` is the first entry returned by `smact.neutral_ratios`.\n",
    "    \"\"\"\n",
    "    symbols = [el.symbol for el in els] + ['O']\n",
    "    el_a, el_b, el_c = els[0], els[1], els[2]\n",
    "\n",
    "    # Pauling electronegativities of the three elements, then oxygen (3.44)\n",
    "    enegs = [el_a.pauling_eneg, el_b.pauling_eneg, el_c.pauling_eneg, 3.44]\n",
    "\n",
    "    passed = []\n",
    "    state_combos = product(el_a.oxidation_states,\n",
    "                           el_b.oxidation_states,\n",
    "                           el_c.oxidation_states)\n",
    "    for state_a, state_b, state_c in state_combos:\n",
    "        # Oxygen is always taken as -2\n",
    "        ox_states = [state_a, state_b, state_c, -2]\n",
    "\n",
    "        # Charge-balance test: skip combinations with no neutral ratio\n",
    "        neutral_exists, ratios = smact.neutral_ratios(ox_states, threshold=8)\n",
    "        if not neutral_exists:\n",
    "            continue\n",
    "\n",
    "        # Electronegativity test\n",
    "        if screening.pauling_test(ox_states, enegs):\n",
    "            passed.append((symbols, ratios[0]))\n",
    "    return passed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Multiprocessing is used to speed things up (generating all compositions takes ~40 minutes on a 4 GHz Intel Core i7 iMac)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of compositions: 3217181\n"
     ]
    }
   ],
   "source": [
    "# Screen each element combination in a separate worker process\n",
    "with multiprocessing.Pool() as pool:\n",
    "    result = pool.map(smact_filter, all_el_combos)\n",
    "\n",
    "# Flatten the per-combination lists into one list of compounds\n",
    "flat_list = []\n",
    "for compounds in result:\n",
    "    flat_list.extend(compounds)\n",
    "print(\"Number of compositions: {0}\".format(len(flat_list)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We turn our generated compositions into pretty formulas, again using multiprocessing. There should be ~1.1M unique formulas. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of unique compositions formulas: 1118505\n"
     ]
    }
   ],
   "source": [
    "def comp_maker(comp):\n",
    "    \"\"\"Turn one generated compound into a reduced formula string.\n",
    "\n",
    "    Args:\n",
    "        comp: pair whose first item holds the element symbols and whose\n",
    "            second item holds the matching amounts.\n",
    "\n",
    "    Returns:\n",
    "        The `reduced_formula` of the pymatgen Composition built from the\n",
    "        concatenated symbol/amount pairs.\n",
    "    \"\"\"\n",
    "    # Interleave symbols and amounts: symbol, amount, symbol, amount, ...\n",
    "    pieces = [str(el) + str(amount) for el, amount in zip(comp[0], comp[1])]\n",
    "    formula = ''.join(pieces)\n",
    "    return Composition(formula).reduced_formula\n",
    "\n",
    "# Convert every generated compound to its reduced formula in parallel\n",
    "with multiprocessing.Pool() as pool:\n",
    "    pretty_formulas = pool.map(comp_maker, flat_list)\n",
    "\n",
    "# De-duplicate, then sort. The iteration order of a set of strings changes\n",
    "# between interpreter sessions (string hash randomisation), so without the\n",
    "# sort the row order of anything built from this list differs on every run.\n",
    "unique_pretty_formulas = sorted(set(pretty_formulas))\n",
    "print(\"Number of unique compositions formulas: {0}\".format(len(unique_pretty_formulas)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pretty_formula</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>1118505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>unique</th>\n",
       "      <td>1118505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>top</th>\n",
       "      <td>Sm2VBrO5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>freq</th>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       pretty_formula\n",
       "count         1118505\n",
       "unique        1118505\n",
       "top          Sm2VBrO5\n",
       "freq                1"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per formula, in a single column named 'pretty_formula'\n",
    "new_data = (\n",
    "    pd.DataFrame(unique_pretty_formulas)\n",
    "    .rename(columns={0: 'pretty_formula'})\n",
    "    .drop_duplicates(subset='pretty_formula')\n",
    ")\n",
    "new_data.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9411c015c4947e8a11cb5839d184eb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='StrToComposition', max=1118505, style=ProgressStyle(descripti…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6ed5582a5f4c4716aaa28b6208b1c318",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='MultipleFeaturizer', max=1118505, style=ProgressStyle(descrip…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Add descriptor columns\n",
    "# this will take a little time as we have over 1 million rows\n",
    "str_to_comp = StrToComposition(target_col_id='composition_obj')\n",
    "# Rebind the returned dataframe instead of relying on in-place mutation.\n",
    "# NOTE(review): the default of featurize_dataframe's `inplace` argument\n",
    "# appears to differ between matminer releases; if the input is not modified\n",
    "# in place, discarding the return value leaves new_data without\n",
    "# 'composition_obj' and the second featurization fails.\n",
    "new_data = str_to_comp.featurize_dataframe(new_data, col_id='pretty_formula')\n",
    "\n",
    "feature_calculators = MultipleFeaturizer([cf.Stoichiometry(), \n",
    "                                          cf.ElementProperty.from_preset(\"magpie\"),\n",
    "                                          cf.ValenceOrbital(props=['avg']), \n",
    "                                          cf.IonProperty(fast=True),\n",
    "                                          cf.BandCenter(), cf.AtomicOrbitals()])\n",
    "\n",
    "feature_labels = feature_calculators.feature_labels()\n",
    "new_data = feature_calculators.featurize_dataframe(new_data, col_id='composition_obj')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the featurized dataframe to a .csv file, 10000 rows at a time\n",
    "new_data.to_csv(path_or_buf='All_oxide_comps_dataframe_featurized.csv',\n",
    "                chunksize=10000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
